{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CrypTen - Private Model Inference using Plans\n",
    "\n",
    "\n",
    "We will do a simple inference on an encrypted neural network that is not known by the local worker.\n",
    "The workers that know the model structure are deployed as [Grid Nodes](https://github.com/OpenMined/PyGrid/tree/dev/apps/node). For this we will be using Plans and we will be using CrypTen as a backend for SMPC. \n",
    "\n",
    "\n",
    "Authors:\n",
    " - George Muraru - Twitter: [@gmuraru](https://twitter.com/georgemuraru)\n",
    " - Ayoub Benaissa - Twitter: [@y0uben11](https://twitter.com/y0uben11)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference Overivew\n",
    "* In this tutorial we will a subset of the the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)\n",
    "* The pre-trained model will be hosted on another worker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "### Download/install needed repos\n",
    "* Clone the [PyGrid repo](https://github.com/OpenMined/PyGrid)\n",
    "  * we need this because *alice* and *bob* are two different Nodes in our network\n",
    "  * install the PyGrid node component using *poetry*\n",
    "\n",
    "### Bring up the PyGridNodes\n",
    "* In the *PyGrid* repo:\n",
    " 1. install *poetry* (```pip install poetry```)\n",
    " 2. go to *apps/nodes*\n",
    " 3. run ```poetry install``` (those steps are also in the README from the PyGrid repo)\n",
    " 4. start *bob* and *alice* using:\n",
    " ```\n",
    " ./run.sh --id alice --port 3000 --start_local_db\n",
    " ./run.sh --id bob --port 3001 --start_local_db\n",
    " ```\n",
    " \n",
    "This will start two workers, *alice* and *bob* and we will connect to them using the port 3000 and 3001."
   ]
  },
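  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before connecting from the notebook, it can help to confirm that both nodes are actually listening. The snippet below is a minimal sketch using only the Python standard library. `port_is_open` is a hypothetical helper (not part of PyGrid or PySyft), and the sketch assumes the nodes run on `localhost` with the ports passed to `run.sh` above.\n",
    "\n",
    "```python\n",
    "import socket\n",
    "\n",
    "\n",
    "def port_is_open(host, port, timeout=1.0):\n",
    "    # Try a plain TCP connection; True means something accepted it.\n",
    "    try:\n",
    "        with socket.create_connection((host, port), timeout=timeout):\n",
    "            return True\n",
    "    except OSError:\n",
    "        return False\n",
    "\n",
    "\n",
    "# Ports taken from the run.sh commands above.\n",
    "for name, port in [('alice', 3000), ('bob', 3001)]:\n",
    "    status = 'up' if port_is_open('localhost', port) else 'not reachable'\n",
    "    print(f'{name} (port {port}): {status}')\n",
    "```\n",
    "\n",
    "If a node is reported as not reachable, check that the corresponding `run.sh` command is still running before continuing."
   ]
  },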
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-08-06 00:04:44--  https://raw.githubusercontent.com/facebookresearch/CrypTen/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.112.133\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.112.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 7401 (7.2K) [text/plain]\n",
      "Saving to: ‘mnist_utils.py’\n",
      "\n",
      "mnist_utils.py      100%[===================>]   7.23K  --.-KB/s    in 0.005s  \n",
      "\n",
      "2020-08-06 00:04:44 (1.32 MB/s) - ‘mnist_utils.py’ saved [7401/7401]\n",
      "\n",
      "--2020-08-06 00:04:44--  https://github.com/facebookresearch/CrypTen/blob/master/tutorials/models/tutorial4_alice_model.pth?raw=true\n",
      "Resolving github.com (github.com)... 140.82.118.3\n",
      "Connecting to github.com (github.com)|140.82.118.3|:443... connected.\n",
      "HTTP request sent, awaiting response... 302 Found\n",
      "Location: https://github.com/facebookresearch/CrypTen/raw/master/tutorials/models/tutorial4_alice_model.pth [following]\n",
      "--2020-08-06 00:04:45--  https://github.com/facebookresearch/CrypTen/raw/master/tutorials/models/tutorial4_alice_model.pth\n",
      "Reusing existing connection to github.com:443.\n",
      "HTTP request sent, awaiting response... 302 Found\n",
      "Location: https://raw.githubusercontent.com/facebookresearch/CrypTen/master/tutorials/models/tutorial4_alice_model.pth [following]\n",
      "--2020-08-06 00:04:45--  https://raw.githubusercontent.com/facebookresearch/CrypTen/master/tutorials/models/tutorial4_alice_model.pth\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.112.133\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.112.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 474857 (464K) [application/octet-stream]\n",
      "Saving to: ‘alice_pretrained_model.pth’\n",
      "\n",
      "alice_pretrained_mo 100%[===================>] 463.73K  1.93MB/s    in 0.2s    \n",
      "\n",
      "2020-08-06 00:04:46 (1.93 MB/s) - ‘alice_pretrained_model.pth’ saved [474857/474857]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget \"https://raw.githubusercontent.com/facebookresearch/CrypTen/b1466440bde4db3e6e1fcb1740584d35a16eda9e/tutorials/mnist_utils.py\" -O \"mnist_utils.py\"\n",
    "!wget \"https://github.com/facebookresearch/CrypTen/blob/master/tutorials/models/tutorial4_alice_model.pth?raw=true\" -O \"alice_pretrained_model.pth\"\n",
    "!python ./mnist_utils.py --option train_v_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the ground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import crypten\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from time import time\n",
    "\n",
    "import syft as sy\n",
    "\n",
    "from syft.frameworks.crypten.model import OnnxModel\n",
    "from syft.grid.clients.data_centric_fl_client import DataCentricFLClient\n",
    "from syft.grid.clients.model_centric_fl_client import ModelCentricFLClient\n",
    "from syft.frameworks.crypten.context import run_multiworkers\n",
    "\n",
    "torch.manual_seed(0)\n",
    "torch.set_num_threads(1)\n",
    "hook = sy.TorchHook(torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neural network that will be known only to the workers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Alice* has the pre-trained version of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AliceNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AliceNet, self).__init__()\n",
    "        self.fc1 = nn.Linear(784, 128)\n",
    "        self.fc2 = nn.Linear(128, 128)\n",
    "        self.fc3 = nn.Linear(128, 10)\n",
    " \n",
    "    def forward(self, x):\n",
    "        out = self.fc1(x)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc3(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup the workers and send them the neural network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to have the two GridNodes workers running.\n",
    "\n",
    "In our current scenario we are sending the serialized model to the workers that are taking part in the computation, but in a real life situation this sending part should not exist - we only need to know that we have the same model on all the workers.\n",
    "\n",
    "### Scenario\n",
    "* The local worker wants to run inference on the data that is hosted on *bob* machine.\n",
    "* The model structure is known only by *alice* and *bob*\n",
    "* *Alice* has the pre-trained network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[%] Connecting to workers...\n",
      "[+] Connected to workers\n",
      "[%] Create the serialized model...\n",
      "[%] Sending the serialized pre-trained model...\n",
      "[+] Serialized model sent to Alice\n",
      "[%] Sending the serialized model...\n",
      "[+] Serialized model sent to Bob\n",
      "[%] Send test data to bob...\n",
      "[+] Data sent to bob\n",
      "[%] Load labels...\n",
      "[+] Labels loaded\n"
     ]
    }
   ],
   "source": [
    "# Syft workers\n",
    "print(\"[%] Connecting to workers...\")\n",
    "alice = DataCentricFLClient(hook, \"ws://localhost:3000\")\n",
    "bob = DataCentricFLClient(hook, \"ws://localhost:3001\")\n",
    "print(\"[+] Connected to workers\")\n",
    "\n",
    "print(\"[%] Create the serialized model...\")\n",
    "dummy_input = torch.empty((1, 784))\n",
    "pytorch_model = AliceNet()\n",
    "\n",
    "# Alice has the model with the real weights\n",
    "print(\"[%] Sending the serialized pre-trained model...\")\n",
    "model_pretrained = torch.load('alice_pretrained_model.pth')\n",
    "model_alice = OnnxModel.fromModel(model_pretrained, dummy_input).tag(\"crypten_model\")\n",
    "alice_model_ptr = model_alice.send(alice)\n",
    "print(\"[+] Serialized model sent to Alice\")\n",
    "    \n",
    "print(\"[%] Sending the serialized model...\")\n",
    "model = OnnxModel.fromModel(pytorch_model, dummy_input).tag(\"crypten_model\")\n",
    "bob_model_ptr = model.send(bob)\n",
    "print(\"[+] Serialized model sent to Bob\")\n",
    "    \n",
    "print(\"[%] Send test data to bob...\")\n",
    "data = torch.load('/tmp/bob_test.pth')\n",
    "data_ptr_bob = data.tag(\"crypten_data\").send(bob)\n",
    "print(\"[+] Data sent to bob\")\n",
    "\n",
    "print(\"[%] Load labels...\")\n",
    "labels = torch.load('/tmp/bob_test_labels.pth').long()\n",
    "print(\"[+] Labels loaded\")\n",
    "\n",
    "\n",
    "# Function used to compute the accuracy for the model\n",
    "# Taken from CrypTen repository\n",
    "def compute_accuracy(output, labels):\n",
    "    pred = output.argmax(1)\n",
    "    correct = pred.eq(labels)\n",
    "    correct_count = correct.sum(0, keepdim=True).float()\n",
    "    accuracy = correct_count.mul_(100.0 / output.size(0))\n",
    "    return accuracy"
   ]
  },
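  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`compute_accuracy` takes the class with the highest score as the prediction for each sample and reports the percentage of predictions that match the labels. The sketch below repeats the same arithmetic on plain Python lists. `accuracy_percent` and the scores are made up for illustration and are not part of the tutorial code.\n",
    "\n",
    "```python\n",
    "def accuracy_percent(scores, labels):\n",
    "    # Predicted class = index of the largest score in each row.\n",
    "    predictions = [row.index(max(row)) for row in scores]\n",
    "    correct = sum(p == y for p, y in zip(predictions, labels))\n",
    "    return 100.0 * correct / len(labels)\n",
    "\n",
    "\n",
    "# Four samples, two classes; the third sample is misclassified.\n",
    "scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]\n",
    "labels = [1, 0, 0, 0]\n",
    "print(accuracy_percent(scores, labels))  # 75.0\n",
    "```\n",
    "\n",
    "The tensor version above does the same thing with `argmax`, `eq` and `sum`, so it works directly on the model output."
   ]
  },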
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the CrypTen computation\n",
    "\n",
    "We need to specify for the ```run_multiworkers``` decorater:\n",
    "* the workers that will take part in the computation\n",
    "* the master address, this will be used for communication\n",
    "\n",
    "We will use the ```func2plan``` decorator to:\n",
    "* trace the operations from our function\n",
    "* sending the plan operations to *alice* and *bob* - the plans operations will act as the function\n",
    "* run the plans operations on both workers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ALICE = 0 # Alice rank in CrypTen\n",
    "BOB = 1 # Bob rank in CrypTen\n",
    "\n",
    "@run_multiworkers([alice, bob], master_addr=\"127.0.0.1\")\n",
    "@sy.func2plan()\n",
    "def run_encrypted_inference(crypten=crypten):\n",
    "    data_enc = crypten.load(\"crypten_data\", BOB)\n",
    "    \n",
    "    data_enc2 = data_enc[:100]\n",
    "    data_flatten = data_enc2.flatten(start_dim=1)\n",
    "    \n",
    "    # This should load the crypten model that is found at all parties\n",
    "    model = crypten.load_model(\"crypten_model\")\n",
    "\n",
    "    model.encrypt(src=ALICE)\n",
    "    model.eval()\n",
    "    \n",
    "    result_enc = model(data_flatten)\n",
    "    result = result_enc.get_plain_text()\n",
    "    \n",
    "    return result"
   ]
  },
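  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CrypTen keeps values private through secret sharing: each party holds a random-looking share, and the value only becomes visible when the shares are combined, which is what `get_plain_text()` requests. The toy example below illustrates additive sharing between two parties on plain integers. It is an intuition aid under simplifying assumptions (integers only, two parties, hypothetical helpers `share` and `reconstruct`), not CrypTen's implementation.\n",
    "\n",
    "```python\n",
    "import secrets\n",
    "\n",
    "RING = 2 ** 64  # all arithmetic is done modulo this value in the toy example\n",
    "\n",
    "\n",
    "def share(value):\n",
    "    # Split value into two shares that sum to value modulo RING.\n",
    "    share_a = secrets.randbelow(RING)\n",
    "    share_b = (value - share_a) % RING\n",
    "    return share_a, share_b\n",
    "\n",
    "\n",
    "def reconstruct(share_a, share_b):\n",
    "    # Combining both shares reveals the original value.\n",
    "    return (share_a + share_b) % RING\n",
    "\n",
    "\n",
    "x_a, x_b = share(20)\n",
    "y_a, y_b = share(22)\n",
    "\n",
    "# Each party adds the shares it holds; neither party sees 20 or 22.\n",
    "sum_a = (x_a + y_a) % RING\n",
    "sum_b = (x_b + y_b) % RING\n",
    "print(reconstruct(sum_a, sum_b))  # 42\n",
    "```\n",
    "\n",
    "Addition works share by share without any communication; operations such as multiplication need extra protocol steps that this sketch leaves out."
   ]
  },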
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the CrypTen computation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's run the inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[%] Starting computation\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[+] Computation finished\n",
      "The accuracy is 99.0\n"
     ]
    }
   ],
   "source": [
    "# Get the returned values\n",
    "# key 0 - return values for alice\n",
    "# key 1 - return values for bob\n",
    "print(\"[%] Starting computation\")\n",
    "result = run_encrypted_inference()[1]\n",
    "print(\"[+] Computation finished\")\n",
    "\n",
    "accuracy = compute_accuracy(result, labels[:100])\n",
    "print(f\"The accuracy is {accuracy.item()}\")"
   ]
  },
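  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comments in the cell above note that the result is indexed by party rank. As a rough mental model, the toy decorator below calls a function once per party and collects the return values in a dictionary keyed by rank. `run_on_parties` is a hypothetical stand-in written for illustration; the real `run_multiworkers` executes the plan remotely on the Grid nodes and is implemented differently.\n",
    "\n",
    "```python\n",
    "import functools\n",
    "\n",
    "\n",
    "def run_on_parties(parties):\n",
    "    # Hypothetical stand-in for run_multiworkers, NOT the PySyft implementation.\n",
    "    def decorator(func):\n",
    "        @functools.wraps(func)\n",
    "        def wrapper(*args, **kwargs):\n",
    "            # Call the function once per party and key the results by rank.\n",
    "            return {rank: func(rank, *args, **kwargs) for rank, _ in enumerate(parties)}\n",
    "        return wrapper\n",
    "    return decorator\n",
    "\n",
    "\n",
    "@run_on_parties(['alice', 'bob'])\n",
    "def toy_computation(rank):\n",
    "    # Every party returns the same public value in this toy example.\n",
    "    return 'prediction'\n",
    "\n",
    "\n",
    "results = toy_computation()\n",
    "print(results[1])  # value returned by the party with rank 1 (bob)\n",
    "```\n",
    "\n",
    "In the tutorial, index `1` selects the value returned on *bob*, matching the `BOB = 1` rank defined earlier."
   ]
  },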
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CleanUp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CleanUp portion taken from the CrypTen project\n",
    "\n",
    "import os\n",
    "\n",
    "\n",
    "filenames = ['/tmp/alice_train.pth', \n",
    "             '/tmp/alice_train_labels.pth', \n",
    "             '/tmp/bob_test.pth', \n",
    "             '/tmp/bob_test_labels.pth',\n",
    "             'alice_pretrained_model.pth',\n",
    "             'mnist_utils.py']\n",
    "\n",
    "for fn in filenames:\n",
    "    if os.path.exists(fn): os.remove(fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one off\" mini-projects by searching for GitHub issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+first+issue+%3Amortar_board%3A%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
